import json
import algorithm

class DifferentialAttack:
    def __init__(self):
        self.SPNenc = algorithm.Encrpt() # 基本的加密算法
        self.differ_table = [[0 for i in range(16)] for j in range(16)]
        json_obj = json.load(open('Plain_Cipher_pairs.json', 'r'))
        self.PC_pairs = json_obj['PC_pairs']
        self.PC_pairs_chosen = [] # 根据PC_pairs选择的明密文对，使得对应位置的delta P 满足差异特征
        # differ_characteristic:差异特征，[Delta P, Delta U]
        self.differ_characteristic = [0b0000101100000000, 0b0000011000000110]
        self.tpk_count = [0 for _ in range(256)] # 每个target partial subkey的计数器

    # 差分分布表的生成
    def gen_differ_table(self):
        # 论文中的Table 7
        for input_sum in range(16):
            for x_0 in range(16):
                x_1 = x_0 ^ input_sum
                y_0 = self.SPNenc.sbox(x_0)
                y_1 = self.SPNenc.sbox(x_1)
                output_sum = y_1 ^ y_0
                self.differ_table[input_sum][output_sum] += 1

    def dec_partial(self,ciphertext,rk):
        # 对密文进行部分解密
        beforeMixKey = ciphertext ^ rk
        beforeSbox = self.SPNenc.sbox_inv(beforeMixKey)
        return beforeSbox

    def choose_PCpairs(self):
        # 选择合适的明密文对
        for plain, _ in self.PC_pairs:
            plain_chosen = plain ^ self.differ_characteristic[0]
            cipher_chosen = self.SPNenc.enc(plain_chosen)
            self.PC_pairs_chosen.append([plain_chosen,cipher_chosen])


    def differ_attack(self):
        # 差分攻击生成第五轮目标子密钥
        self.choose_PCpairs()

        # 筛选掉不符合特征的PCpairs
        correct_ind = []
        for i in range(10000):
            cipher0 = self.PC_pairs[i][1]
            cipher1 = self.PC_pairs_chosen[i][1]
            if (cipher0 ^ cipher1) & 0xf0f0 == 0:
                correct_ind.append(i)
        correct_pair_len = len(correct_ind)



        for tpk in range(256):
            subkey = (((tpk >> 4) & 0xf) << 8) ^ (tpk & 0xf)
            for i in range(correct_pair_len):
                U0 = self.dec_partial(self.PC_pairs[correct_ind[i]][1],subkey)
                U1 = self.dec_partial(self.PC_pairs_chosen[correct_ind[i]][1],subkey)
                if (U0 ^ U1) & 0x0f0f == self.differ_characteristic[1]:
                    self.tpk_count[tpk] += 1

        # 寻找count对应偏差最大的子密钥作为攻击结果
        max, max_ind = 0,0
        for tpk in range(256):
            # self.tpk_count[tpk] /= 10000
            if max < self.tpk_count[tpk]:
                max = self.tpk_count[tpk]
                max_ind = tpk
        print("目标部分子密钥分别为：",hex(max_ind>>4),hex(max_ind & 0xf))
        print("实际第五轮子密钥为：",hex(self.SPNenc.rk[4]))



if __name__ == "__main__":
    attack = DifferentialAttack()

    # 测试生成线性近似表
    attack.gen_differ_table()
    print("差分分布表的生成结果：")
    for row in attack.differ_table:
        print(row)

    # 测试线性攻击能否恢复target partial key
    attack.differ_attack()
